{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from modeling_llama import LlamaForCausalLM\n",
    "from transformers import AutoTokenizer\n",
    "from decoding import clip_input, infer_input_ids\n",
    "from rouge_score import rouge_scorer\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace nn.Linear.reset_parameters with a no-op before the model is built\n",
    "# (presumably to skip random init that the checkpoint weights overwrite -- TODO confirm).\n",
    "def _skip_reset_parameters(module):\n",
    "    return None\n",
    "\n",
    "torch.nn.Linear.reset_parameters = _skip_reset_parameters\n",
    "\n",
    "checkpoint_dir = '../llama-2-13b/'\n",
    "device = 'cuda:6'\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)\n",
    "# Load the weights as bfloat16, move them to `device` and switch the model to eval mode.\n",
    "model = LlamaForCausalLM.from_pretrained(checkpoint_dir, torch_dtype=torch.bfloat16).to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the CUDA, torch and transformers versions in use, one per line.\n",
    "import transformers\n",
    "\n",
    "for version in (torch.version.cuda, torch.__version__, transformers.__version__):\n",
    "    print(version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the model for its attention / MLP skip-layer id sets and print both.\n",
    "skip_layers = model.get_skip_layers()\n",
    "_attn_skip_layer_id_set, _mlp_skip_layer_id_set = skip_layers\n",
    "print(_attn_skip_layer_id_set, _mlp_skip_layer_id_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch and numpy with one shared value; `seed` is also passed to the dataset shuffles.\n",
    "seed = 42\n",
    "for seed_fn in (torch.manual_seed, np.random.seed):\n",
    "    seed_fn(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the evaluation subset and the few-shot prompt from datasets fetched with load_dataset.\n",
    "from datasets import load_dataset\n",
    "\n",
    "n_shot = 1\n",
    "task_name = 'cnndm'\n",
    "\n",
    "# 1000 shuffled test examples are evaluated; n_shot shuffled train examples feed the prompt.\n",
    "# prompt_keys holds the (source text, reference summary) column names of the chosen dataset.\n",
    "if task_name == 'xsum':\n",
    "    data = load_dataset('xsum', split='test').shuffle(seed=seed).select(range(1000))\n",
    "    shots = load_dataset('xsum', split='train').shuffle(seed=seed).select(range(n_shot))\n",
    "    prompt_keys = ['document', 'summary']\n",
    "elif task_name == 'cnndm':\n",
    "    data = load_dataset('cnn_dailymail', name='3.0.0', split='test').shuffle(seed=seed).select(range(1000))\n",
    "    shots = load_dataset('cnn_dailymail', name='3.0.0', split='train').shuffle(seed=seed).select(range(n_shot))\n",
    "    prompt_keys = ['article', 'highlights']\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on `shots` further down.\n",
    "    raise ValueError(f'Unsupported task_name: {task_name!r} (expected xsum or cnndm)')\n",
    "\n",
    "# One 'Article: ... Summary: ...' block per shot; newlines inside the summary are removed.\n",
    "prompt_shots = ''.join(\n",
    "    'Article: ' + shots[i][prompt_keys[0]] + '\\nSummary: ' + shots[i][prompt_keys[1]].replace('\\n', '') + '\\n'\n",
    "    for i in range(n_shot)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Offline variant: read the evaluation set and the few-shot examples with load_from_disk.\n",
    "# Running this cell rebinds data, shots, prompt_keys and prompt_shots, replacing whatever\n",
    "# an earlier cell assigned to them.\n",
    "from datasets import load_from_disk\n",
    "\n",
    "n_shot = 1\n",
    "task_name = 'cnndm'\n",
    "\n",
    "# prompt_keys holds the (source text, reference summary) column names of the chosen dataset.\n",
    "if task_name == 'xsum':\n",
    "    data = load_from_disk('../xsum/test/1000')\n",
    "    shots = load_from_disk('../xsum/shots')\n",
    "    prompt_keys = ['document', 'summary']\n",
    "elif task_name == 'cnndm':\n",
    "    data = load_from_disk('../cnndm/test/1000')\n",
    "    shots = load_from_disk('../cnndm/shots')\n",
    "    prompt_keys = ['article', 'highlights']\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on `shots` further down.\n",
    "    raise ValueError(f'Unsupported task_name: {task_name!r} (expected xsum or cnndm)')\n",
    "\n",
    "# One 'Article: ... Summary: ...' block per shot; newlines inside the summary are removed.\n",
    "prompt_shots = ''.join(\n",
    "    'Article: ' + shots[i][prompt_keys[0]] + '\\nSummary: ' + shots[i][prompt_keys[1]].replace('\\n', '') + '\\n'\n",
    "    for i in range(n_shot)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eyeball the assembled few-shot prompt before it is used for generation.\n",
    "print(prompt_shots)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare 'base' decoding with 'essg' decoding on every example in `data`: four runs with a\n",
    "# fixed draft-stop threshold plus one run with auto_th_stop_draft=True. After each example\n",
    "# the running means of ROUGE-2, latency, speed-up and draft statistics go to a log file.\n",
    "MAX_NEW_TOKENS = 512\n",
    "MAX_STEP_DRAFT = 12\n",
    "# Forwarded unchanged as infer_input_ids(auto_parameters=...); the meaning of each entry is\n",
    "# defined in the decoding module, not here.\n",
    "AUTO_PARAMETERS = [1, 0.50, 0.90, 1e-2, 0.90]\n",
    "# Completions are cut at the first occurrence of this marker before they are scored.\n",
    "STOP_MARKER = '\\nArticle:'\n",
    "LOG_PATH = 'CNNDM_llama13b_1shot_bayesian-prompt-new_maxtoken512_maxstep12_essg_autoth0.60-st1-mat0.90-var1e-2-mom1-0.50-mom2-0.90_abla-th2468_env4-33-1_torch_1-13-1_cuda11.6_dec1.txt'\n",
    "\n",
    "FIXED_KEYS = ['essg1', 'essg2', 'essg3', 'essg4']\n",
    "AUTO_KEY = 'essg_autoth'\n",
    "ESSG_KEYS = FIXED_KEYS + [AUTO_KEY]\n",
    "RUN_KEYS = ['base'] + ESSG_KEYS\n",
    "# Thresholds used for the first example; later examples reuse the value each run reports back.\n",
    "INITIAL_TH_STOP_DRAFT = {'essg1': 0.20, 'essg2': 0.40, 'essg3': 0.60, 'essg4': 0.80, AUTO_KEY: 0.60}\n",
    "REFERENCE_COLUMN = {'xsum': 'summary', 'cnndm': 'highlights'}\n",
    "\n",
    "rouge = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)\n",
    "\n",
    "# One list per (statistic, run); matchness / num_drafted_tokens exist only for essg runs.\n",
    "main_metrics = {f'{stat}_{key}': [] for stat in ('rouge2', 'time', 'token_time') for key in RUN_KEYS}\n",
    "main_metrics.update({f'{stat}_{key}': [] for stat in ('matchness', 'num_drafted_tokens') for key in ESSG_KEYS})\n",
    "\n",
    "# Resolved before the log file is opened (and truncated) so an unknown task fails early.\n",
    "reference_column = REFERENCE_COLUMN[task_name]\n",
    "\n",
    "with open(LOG_PATH, 'w') as f:\n",
    "    th_next = dict(INITIAL_TH_STOP_DRAFT)\n",
    "    for i, x in enumerate(data):\n",
    "        input_ids = clip_input(tokenizer, x, task_name, max_new_tokens=MAX_NEW_TOKENS, prompt_shots=prompt_shots)\n",
    "\n",
    "        # Thresholds used for this example (also used to label the metrics below).\n",
    "        th_used = th_next\n",
    "        f.write('essg th1: {:.4f}, essg th2: {:.4f}, essg th3: {:.4f}, essg th4: {:.4f}, essg autoth: {:.4f} \\n'.format(\n",
    "            *(th_used[key] for key in ESSG_KEYS)))\n",
    "\n",
    "        # Insertion order (base, essg1..4, essg_autoth) is the order the runs execute in.\n",
    "        results = {\n",
    "            'base': infer_input_ids(model, tokenizer, input_ids, generate_fn='base',\n",
    "                                    max_new_tokens=MAX_NEW_TOKENS, do_sample=False, early_stop=True)\n",
    "        }\n",
    "        for key in ESSG_KEYS:\n",
    "            is_auto = key == AUTO_KEY\n",
    "            auto_kwargs = {'auto_parameters': AUTO_PARAMETERS} if is_auto else {}\n",
    "            results[key] = infer_input_ids(model, tokenizer, input_ids, generate_fn='essg',\n",
    "                                           max_new_tokens=MAX_NEW_TOKENS, early_stop=True,\n",
    "                                           max_step_draft=MAX_STEP_DRAFT,\n",
    "                                           th_stop_draft=th_used[key], auto_th_stop_draft=is_auto,\n",
    "                                           do_sample=False, do_sample_draft=False, **auto_kwargs)\n",
    "\n",
    "        # Carry each run's reported threshold to the next example, even if this one is skipped.\n",
    "        th_next = {key: results[key]['th_stop_draft'] for key in ESSG_KEYS}\n",
    "\n",
    "        base_completion = results['base']['completion']\n",
    "        if len(base_completion) < 5 or ('.....' in base_completion[:5]):\n",
    "            print('too short, skip')\n",
    "            continue\n",
    "\n",
    "        references = x[reference_column]\n",
    "\n",
    "        for key, result in results.items():\n",
    "            main_metrics['time_' + key].append(result['time'])\n",
    "            main_metrics['token_time_' + key].append(result['time'] / result['generate_ids'].shape[1])\n",
    "            if key != 'base':\n",
    "                main_metrics['matchness_' + key].append(result['matchness'])\n",
    "                main_metrics['num_drafted_tokens_' + key].append(result['num_drafted_tokens'])\n",
    "            # split() also handles a marker at index 0, which a `find(...) > 0` test would\n",
    "            # mistake for 'marker not found' and score the whole run-on completion.\n",
    "            prediction = result['completion'].split(STOP_MARKER, 1)[0]\n",
    "            # RougeScorer.score takes (target, prediction), in that order.\n",
    "            rouge_score = rouge.score(references, prediction)\n",
    "            main_metrics['rouge2_' + key].append(rouge_score['rouge2'].fmeasure)\n",
    "\n",
    "        # Human-readable run labels; fixed-threshold runs are labelled with the threshold used.\n",
    "        labels = {'base': 'base', AUTO_KEY: 'essg autoth'}\n",
    "        for key in FIXED_KEYS:\n",
    "            labels[key] = f'essg th {th_used[key]}'\n",
    "        mean = {name: np.mean(values) for name, values in main_metrics.items()}\n",
    "\n",
    "        # Keys are inserted in a fixed order because the dict repr is what gets logged.\n",
    "        metric = {}\n",
    "        for key in RUN_KEYS:\n",
    "            metric[f'mean rouge-2 {labels[key]}'] = mean['rouge2_' + key]\n",
    "        for stat, title, speedup_title in (('time', 'mean time', 'E2E mean speed up'),\n",
    "                                           ('token_time', 'mean token time', 'E2E mean token speed up')):\n",
    "            for key in RUN_KEYS:\n",
    "                metric[f'{title} {labels[key]}'] = mean[f'{stat}_{key}']\n",
    "            # Speed-up is the ratio of running means: base over the essg run.\n",
    "            for key in ESSG_KEYS:\n",
    "                metric[f'{speedup_title} {labels[key]}'] = mean[f'{stat}_base'] / mean[f'{stat}_{key}']\n",
    "        for stat, title in (('matchness', 'mean matchness'), ('num_drafted_tokens', 'mean num_drafted_tokens')):\n",
    "            for key in ESSG_KEYS:\n",
    "                metric[f'{title} {labels[key]}'] = mean[f'{stat}_{key}']\n",
    "\n",
    "        # Floats are rendered with four decimals; anything else is logged as-is.\n",
    "        metric = {name: (f'{value:.4f}' if isinstance(value, float) else value)\n",
    "                  for name, value in metric.items()}\n",
    "\n",
    "        f.write(f'data {i},{metric} \\n')\n",
    "        # Flush per example so partial results survive an interrupted run.\n",
    "        f.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "skipgpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
